{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import sys\n",
    "import numpy as np\n",
    "import pickle\n",
    "import tensorflow as tf\n",
    "import PIL\n",
    "import ipywidgets\n",
    "import io\n",
    "\n",
    "\"\"\" make sure this notebook is running from root directory \"\"\"\n",
    "while os.path.basename(os.getcwd()) in ('notebooks', 'src'):\n",
    "    os.chdir('..')\n",
    "assert ('README.md' in os.listdir('./')), 'Can not find project root, please cd to project root before running the following code'\n",
    "\n",
    "import src.tl_gan.generate_image as generate_image\n",
    "import src.tl_gan.feature_axis as feature_axis\n",
    "import src.tl_gan.feature_celeba_organize as feature_celeba_organize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\" load feature directions \"\"\"\n",
    "path_feature_direction = './asset_results/pg_gan_celeba_feature_direction_40'\n",
    "\n",
    "pathfile_feature_direction = glob.glob(os.path.join(path_feature_direction, 'feature_direction_*.pkl'))[-1]\n",
    "\n",
    "with open(pathfile_feature_direction, 'rb') as f:\n",
    "    feature_direction_name = pickle.load(f)\n",
    "\n",
    "feature_direction = feature_direction_name['direction']\n",
    "feature_name = feature_direction_name['name']\n",
    "num_feature = feature_direction.shape[1]\n",
    "\n",
    "import importlib\n",
    "importlib.reload(feature_celeba_organize)\n",
    "feature_name = feature_celeba_organize.feature_name_celeba_rename\n",
    "feature_direction = feature_direction_name['direction']* feature_celeba_organize.feature_reverse[None, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\" start tf session and load GAN model \"\"\"\n",
    "\n",
    "# path to model code and weight\n",
    "path_pg_gan_code = './src/model/pggan'\n",
    "path_model = './asset_model/karras2018iclr-celebahq-1024x1024.pkl'\n",
    "sys.path.append(path_pg_gan_code)\n",
    "\n",
    "\n",
    "\"\"\" create tf session \"\"\"\n",
    "yn_CPU_only = False\n",
    "\n",
    "if yn_CPU_only:\n",
    "    config = tf.ConfigProto(device_count = {'GPU': 0}, allow_soft_placement=True)\n",
    "else:\n",
    "    config = tf.ConfigProto(allow_soft_placement=True)\n",
    "    config.gpu_options.allow_growth = True\n",
    "\n",
    "sess = tf.InteractiveSession(config=config)\n",
    "\n",
    "try:\n",
    "    with open(path_model, 'rb') as file:\n",
    "        G, D, Gs = pickle.load(file)\n",
    "except FileNotFoundError:\n",
    "    print('before running the code, download pre-trained model to project_root/asset_model/')\n",
    "    raise\n",
    "\n",
    "len_z = Gs.input_shapes[0][1]\n",
    "z_sample = np.random.randn(len_z)\n",
    "x_sample = generate_image.gen_single_img(z_sample, Gs=Gs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def img_to_bytes(x_sample):\n",
    "    imgObj = PIL.Image.fromarray(x_sample)\n",
    "    imgByteArr = io.BytesIO()\n",
    "    imgObj.save(imgByteArr, format='PNG')\n",
    "    imgBytes = imgByteArr.getvalue()\n",
    "    return imgBytes\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "press +/- to adjust feature, toggle feature name to lock the feature\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "77b663b016424c9f9e7a0966aa3fd872",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(Image(value=b'\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR\\x00\\x00\\x04\\x00\\x00\\x00\\x04\\x00\\x08\\x02\\x00\\x…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "z_sample = np.random.randn(len_z)\n",
    "x_sample = generate_image.gen_single_img(Gs=Gs)\n",
    "\n",
    "w_img = ipywidgets.widgets.Image(value=img_to_bytes(x_sample), fromat='png', width=512, height=512)\n",
    "\n",
    "class GuiCallback(object):\n",
    "    counter = 0\n",
    "    #     latents = z_sample\n",
    "    def __init__(self):\n",
    "        self.latents = z_sample\n",
    "        self.feature_direction = feature_direction\n",
    "        self.feature_lock_status = np.zeros(num_feature).astype('bool')\n",
    "        self.feature_directoion_disentangled = feature_axis.disentangle_feature_axis_by_idx(\n",
    "            self.feature_direction, idx_base=np.flatnonzero(self.feature_lock_status))\n",
    "\n",
    "    def random_gen(self, event):\n",
    "        self.latents = np.random.randn(len_z)\n",
    "        self.update_img()\n",
    "\n",
    "    def modify_along_feature(self, event, idx_feature, step_size=0.01):\n",
    "        self.latents += self.feature_directoion_disentangled[:, idx_feature] * step_size\n",
    "        self.update_img()\n",
    "\n",
    "    def set_feature_lock(self, event, idx_feature, set_to=None):\n",
    "        if set_to is None:\n",
    "            self.feature_lock_status[idx_feature] = np.logical_not(self.feature_lock_status[idx_feature])\n",
    "        else:\n",
    "            self.feature_lock_status[idx_feature] = set_to\n",
    "        self.feature_directoion_disentangled = feature_axis.disentangle_feature_axis_by_idx(\n",
    "            self.feature_direction, idx_base=np.flatnonzero(self.feature_lock_status))\n",
    "    \n",
    "    def update_img(self):        \n",
    "        x_sample = generate_image.gen_single_img(z=self.latents, Gs=Gs)\n",
    "        x_byte = img_to_bytes(x_sample)\n",
    "        w_img.value = x_byte\n",
    "\n",
    "guicallback = GuiCallback()\n",
    "\n",
    "step_size = 0.4\n",
    "def create_button(idx_feature, width=96, height=40):\n",
    "    \"\"\" function to built button groups for one feature \"\"\"\n",
    "    w_name_toggle = ipywidgets.widgets.ToggleButton(\n",
    "        value=False, description=feature_name[idx_feature],\n",
    "        tooltip='{}, Press down to lock this feature'.format(feature_name[idx_feature]),\n",
    "        layout=ipywidgets.Layout(height='{:.0f}px'.format(height/2), \n",
    "                                 width='{:.0f}px'.format(width),\n",
    "                                 margin='2px 2px 2px 2px')\n",
    "    )\n",
    "    w_neg = ipywidgets.widgets.Button(description='-',\n",
    "                                      layout=ipywidgets.Layout(height='{:.0f}px'.format(height/2), \n",
    "                                                               width='{:.0f}px'.format(width/2),\n",
    "                                                               margin='1px 1px 5px 1px'))\n",
    "    w_pos = ipywidgets.widgets.Button(description='+',\n",
    "                                      layout=ipywidgets.Layout(height='{:.0f}px'.format(height/2), \n",
    "                                                               width='{:.0f}px'.format(width/2),\n",
    "                                                               margin='1px 1px 5px 1px'))\n",
    "    \n",
    "    w_name_toggle.observe(lambda event: \n",
    "                      guicallback.set_feature_lock(event, idx_feature))\n",
    "    w_neg.on_click(lambda event: \n",
    "                     guicallback.modify_along_feature(event, idx_feature, step_size=-1 * step_size))\n",
    "    w_pos.on_click(lambda event: \n",
    "                     guicallback.modify_along_feature(event, idx_feature, step_size=+1 * step_size))\n",
    "    \n",
    "    button_group = ipywidgets.VBox([w_name_toggle, ipywidgets.HBox([w_neg, w_pos])],\n",
    "                                  layout=ipywidgets.Layout(border='1px solid gray'))\n",
    "    \n",
    "    return button_group\n",
    "  \n",
    "\n",
    "list_buttons = []\n",
    "for idx_feature in range(num_feature):\n",
    "    list_buttons.append(create_button(idx_feature))\n",
    "\n",
    "yn_button_select = True\n",
    "def arrange_buttons(list_buttons, yn_button_select=True, ncol=4):\n",
    "    num = len(list_buttons)\n",
    "    if yn_button_select:\n",
    "        feature_celeba_layout = feature_celeba_organize.feature_celeba_layout\n",
    "        layout_all_buttons = ipywidgets.VBox([ipywidgets.HBox([list_buttons[item] for item in row]) for row in feature_celeba_layout])\n",
    "    else:\n",
    "        layout_all_buttons = ipywidgets.VBox([ipywidgets.HBox(list_buttons[i*ncol:(i+1)*ncol]) for i in range(num//ncol+int(num%ncol>0))])\n",
    "    return layout_all_buttons\n",
    "    \n",
    "\n",
    "# w_button.on_click(on_button_clicked)\n",
    "guicallback.update_img()\n",
    "w_button_random = ipywidgets.widgets.Button(description='random face', button_style='success',\n",
    "                                           layout=ipywidgets.Layout(height='40px', \n",
    "                                                               width='128px',\n",
    "                                                               margin='1px 1px 5px 1px'))\n",
    "w_button_random.on_click(guicallback.random_gen)\n",
    "\n",
    "w_box = ipywidgets.HBox([w_img, \n",
    "                         ipywidgets.VBox([w_button_random, \n",
    "                                         arrange_buttons(list_buttons, yn_button_select=True)])\n",
    "                        ], layout=ipywidgets.Layout(height='1024}px', width='1024px')\n",
    "                       )\n",
    "\n",
    "print('press +/- to adjust feature, toggle feature name to lock the feature')\n",
    "display(w_box)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_tensorflow_p36)",
   "language": "python",
   "name": "conda_tensorflow_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
